from model.SNUNet import SNUNet_ECAM
import torch

# Smoke test: push two random inputs through SNUNet_ECAM and print the shape
# of the first entry of whatever the model returns.
# NOTE(review): the arguments (4, 2) are presumably (input channels, output
# classes); the 4 matches the channel dimension of the tensors below --
# confirm against the SNUNet_ECAM constructor.
model = SNUNet_ECAM(4, 2)

# Two random tensors of identical shape (2, 4, 512, 512), passed to the model
# as a pair.
input_shape = (2, 4, 512, 512)
x = torch.randn(input_shape)
x2 = torch.randn(input_shape)

# Only the output shape is inspected, so skip autograd bookkeeping: this avoids
# retaining intermediate activations for a backward pass that never happens.
with torch.no_grad():
    # Bind the result to its own name instead of overwriting the input `x`.
    output = model(x, x2)

# NOTE(review): `[0]` selects either the first element of a returned
# tuple/list or the first sample of a returned tensor -- the return type of
# SNUNet_ECAM is not visible from this file; confirm.
print(output[0].shape)
